#!/usr/bin/env python

# This implements the SageMaker serving service shell.  It starts nginx and gunicorn.
# Parameter               Env Var                            Default Value
# timeout                 BENTOML_GUNICORN_TIMEOUT           60s
# number of workers       BENTOML_GUNICORN_NUM_OF_WORKERS    number of cpu cores / 2 + 1
# api name                API_NAME                           None

import subprocess
import os
import signal
import sys
from dependency_injector.wiring import inject, Provide

from bentoml.configuration.containers import BentoMLConfiguration, BentoMLContainer


def sigterm_handler(nginx_pid, gunicorn_pid):
    """Signal both server processes to stop, then exit with status 0.

    SIGQUIT is sent to ``nginx_pid`` and SIGTERM to ``gunicorn_pid``.
    Delivery is best effort: an ``OSError`` raised by ``os.kill`` (for
    example because the process is already gone) is ignored so that the
    other process is still signalled.

    This function never returns normally; it always ends by calling
    ``sys.exit(0)``.
    """
    shutdown_plan = (
        (nginx_pid, signal.SIGQUIT),
        (gunicorn_pid, signal.SIGTERM),
    )
    for target_pid, stop_signal in shutdown_plan:
        try:
            os.kill(target_pid, stop_signal)
        except OSError:
            # Nothing to do if the signal cannot be delivered; move on.
            continue

    sys.exit(0)


@inject
def _serve(
    bento_server_timeout: int = Provide[BentoMLContainer.config.api_server.timeout],
    bento_server_workers: int = Provide[BentoMLContainer.api_server_workers],
):
    """Start nginx and gunicorn and supervise them until one of them exits.

    Both arguments are normally supplied by dependency injection from
    ``BentoMLContainer``.

    Args:
        bento_server_timeout: value passed to gunicorn's ``--timeout`` option.
        bento_server_workers: value passed to gunicorn's ``-w`` option.

    Raises:
        subprocess.CalledProcessError: if linking the nginx log files fails.
        SystemExit: always, once either child process has exited or SIGTERM
            has been received (raised by ``sigterm_handler``).
    """
    # link the log streams to stdout/err so they will be logged to the container logs
    subprocess.check_call(['ln', '-sf', '/dev/stdout', '/var/log/nginx/access.log'])
    subprocess.check_call(['ln', '-sf', '/dev/stderr', '/var/log/nginx/error.log'])

    nginx = subprocess.Popen(['nginx', '-c', '/bento/nginx.conf'])
    gunicorn_app = subprocess.Popen(
        [
            'gunicorn',
            '--timeout',
            str(bento_server_timeout),
            '-k',
            'gevent',
            '-b',
            'unix:/tmp/gunicorn.sock',
            '-w',
            str(bento_server_workers),
            'wsgi:app',
        ]
    )
    # On SIGTERM, forward a stop signal to both children and exit.
    signal.signal(
        signal.SIGTERM, lambda a, b: sigterm_handler(nginx.pid, gunicorn_app.pid)
    )

    # Block until either nginx or gunicorn exits; any other reaped child
    # is ignored and the wait continues.
    pids = {nginx.pid, gunicorn_app.pid}
    while True:
        pid, _ = os.wait()
        if pid in pids:
            break

    # sigterm_handler() finishes with sys.exit(0), so anything placed after
    # the call would never run; emit the exit message first.
    print('Inference server exiting')
    sigterm_handler(nginx.pid, gunicorn_app.pid)


if __name__ == '__main__':
    # Build the BentoML container and its configuration, apply any
    # environment-variable overrides, then wire this module for injection
    # and start serving.
    container = BentoMLContainer()
    config = BentoMLConfiguration()

    # (environment variable, configuration path) pairs; each value is
    # parsed as an integer and applied only when the variable is set.
    env_overrides = (
        ("BENTOML_GUNICORN_TIMEOUT", ["api_server", "timeout"]),
        ("BENTOML_GUNICORN_NUM_OF_WORKERS", ["api_server", "workers"]),
    )
    for env_name, config_path in env_overrides:
        raw_value = os.environ.get(env_name)
        if raw_value is not None:
            config.override(config_path, int(raw_value))

    container.config.from_dict(config.as_dict())

    container.wire(modules=[sys.modules[__name__]])

    _serve()
